{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as Data\n",
    "torch.manual_seed(8) # for reproduce\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import gc\n",
    "import sys\n",
    "sys.setrecursionlimit(50000)\n",
    "import pickle\n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "from tensorboardX import SummaryWriter\n",
    "torch.nn.Module.dump_patches = True\n",
    "import copy\n",
    "import pandas as pd\n",
    "#then import my own modules\n",
    "from AttentiveFP import Fingerprint, Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import matthews_corrcoef\n",
    "from sklearn.metrics import recall_score\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import auc\n",
    "from sklearn.metrics import f1_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from rdkit.Chem import rdMolDescriptors, MolSurf\n",
    "# from rdkit.Chem.Draw import SimilarityMaps\n",
    "from rdkit import Chem\n",
    "# from rdkit.Chem import AllChem\n",
    "from rdkit.Chem import QED\n",
    "%matplotlib inline\n",
    "from numpy.polynomial.polynomial import polyfit\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib\n",
    "from IPython.display import SVG, display\n",
    "import seaborn as sns; sns.set(color_codes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of all smiles:  1484\n",
      "not successfully processed smiles:  [NH4][Pt]([NH4])(Cl)Cl\n",
      "not successfully processed smiles:  c1ccc(cc1)n2c(=O)c(c(=O)n2c3ccccc3)CCS(=O)c4ccccc4\n",
      "not successfully processed smiles:  Cc1cc2c(cc1C)N3C=N2[Co+]456(N7=C8[C@H](C(C7=CC9=N4C(=C(C1=N5[C@@]([C@@H]2N6C(=C8C)[C@@]([C@H]2CC(=O)N)(CCC(=O)NC[C@H](OP(=O)(O[C@@H]2[C@H](O[C@H]3[C@@H]2O)CO)[O-])C)C)([C@@]([C@@H]1CCC(=O)N)(C)CC(=O)N)C)C)[C@@]([C@@H]9CCC(=O)N)(C)CC(=O)N)(C)C)CCC(=O)N)O\n",
      "not successfully processed smiles:  Cc1cc2c(cc1C)N3C=N2[Co]456(N7=C8[C@H](C(C7=CC9=N4C(=C(C1=N5[C@@]([C@@H]2N6C(=C8C)[C@@]([C@H]2CC(=O)N)(CCC(=O)NC[C@H](OP(=O)(O[C@@H]2[C@H](O[C@H]3[C@@H]2O)CO)O)C)C)([C@@]([C@@H]1CCC(=O)N)(C)CC(=O)N)C)C)[C@@]([C@@H]9CCC(=O)N)(C)CC(=O)N)(C)C)CCC(=O)N)C#N\n",
      "not successfully processed smiles:  CCCCc1c(=O)n(n(c1=O)c2ccc(cc2)O)c3ccccc3\n",
      "not successfully processed smiles:  CCCCc1c(=O)n(n(c1=O)c2ccccc2)c3ccccc3\n",
      "number of successfully processed smiles:  1478\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU8AAAC/CAYAAAB+KF5fAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF2lJREFUeJzt3X9QFOf9B/A38iNwOQ4hoamKigclgQgoIWBVJGBSijod6sQpDiI3YaBO9SRiYseRakKbRCfGTvR6IjaDKTrSH9qmUUJsCQG0NkyHMMkkaiVXCVBrjAJ3B8cPj/v+kXGT81CWJ8f94Pt+/cU9+9ze5xb37e7z7C4+NpvNBiIimpBp7i6AiMgbMTyJiAQwPImIBDA8iYgEMDyJiAQwPImIBPi5u4CJuH7dJLtvaKgCPT0Dk1jN5GL97sX63ctT6g8PD77rsil75Onn5+vuEr4V1u9erN+9vKH+KRueRESTieFJRCSA4UlEJIDhSUQkgOFJRCTAqy5Vmireb+set8+apx5xQSVEJIpHnkREAhieREQCGJ5ERAIYnkREAhieREQCGJ5ERAIYnkREAsYNz7q6OuTm5iI1NRXx8fHIysqCXq/H8PCw1Mdms6GiogLp6elISEhAXl4eLly44LCu9vZ2FBQUIDExEUuXLsXrr78Oq9Xq3G9EROQC414k39vbi9TUVBQWFiI4OBgfffQRdDodvvzyS+zcuRMAUFlZCb1ej23btkGtVqOqqgoajQanTp1CeHg4AKCvrw8ajQbR0dHQ6/X4/PPPsWfPHoyOjmLLli2T+y2JiJxs3PDMzc21e71o0SL09/fj2LFj+MUvfoHh4WFUVlaiuLgY69atAwAsWLAAmZmZOHr0qBSMNTU1GBoagk6ng1KpxJIlS2A2m6HT6VBUVASlUjkJX4+IaHIIjXlOnz4dIyMjAIDW1laYzWZkZ2dLyxUKBTIyMtDc3Cy1NTU1YenSpXYhuXLlSgwODqKlpUW0fiIit5AdnlarFRaLBf/6179QXV2NtWvXwsfHBwaDAb6+voiMjLTrHxUVBYPBIL02GAxQq9V2fWbOnImgoCC7fkRE3kD2g0EWLFggTRLl5ORg27ZtAACj0QiFQgFfX/vH5oeEhMBisWB4eBgBAQEwGo0IDnb8eyAqlQpGo1FWDaGhigk9nv9ef3/EnYKVgbL6eWr9crF+92L9k0t2eNbU1MBiseDjjz/Gb37zG5SXl+OFF14AAPj4+Dj0t9lsDsvu1m+s9rFM5A9ChYcHT+gPxrmSyTwoq5+n1i+HJ29/OVi/e3lK/fcKcNnh+eijjwIAkpOTERoaip///Od45plnoFKp0N/fD6vVanf0aTQaERQUBH9/fwBfHWGaTI4bw2w2j3lESkTkyYQmjOLi4gAAXV1dUKvVsFqt6OjosOtz5xinWq12GNu8evUqBgYGHMZCiYg8nVB4tra2AgAiIiKQlJQEpVKJuro6abnFYkFDQwPS0tKktmXLluHs2bMwm81SW21tLQIDA5GSkiJaPxGRW4x72l5YWIjFixcjOjoavr6+aG1tRVVVFVasWIE5c+YAAIqLi6HX6xESEiJdJD86Oor8/HxpPbm5uaiuroZWq0VRURE6Ozuh0+mg0Wh4jScReZ1xwzM+Ph5//vOf0d3dDV9fX8yePRulpaV2F88XFxdjdHQUhw4dQm9vL+bPn4+qqio8+OCDUp+QkBAcOXIE5eXl2LBhA1QqFQoKCqDVaifnmxERTSIf2+1pcS8wkdk3T5mtG4vcv2HkqfXL4cnbXw7W716eUv+9Ztv5VCUiIgEMTyIiAQxPIiIBDE8iIgEMTyIiAbJvzyR55MykE5H345EnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkYBxw/Odd97Bhg0bkJaWhoULF2L16tU4deqUQ78//OEP+MEPfoD4+HisXr0a58+fd+hz7do1bNy4EQsXLkRqairKy8thsVic802IiFzIb7wOR44cQUREBLZv347Q0FA0NTVh69at6OnpQX5+PgDg9OnT2LVrFzZt2oTHHnsMJ0+exE9/+lP86U9/QkxMDADg1q1bKCwshL+/P37961/DaDRi9+7dMBqN2Lt37+R+SyIiJxs3PA8ePIiwsDDp9fe//3188cUXqKqqksJz//79yMnJwcaNGwEAKSkpuHDhAiorK6VgrKurw2effYYzZ85g9uzZX324nx9KS0uxadMmREZGOvu7ERFNmnFP278ZnLfFxsbi5s2bAIDOzk5cuXIF2dnZX6902jRkZWWhublZamtqakJ8fLwUnADw5JNPwt/f364fEZE3EJow+vDDDxEVFQUAMBgMAAC1Wm3XJyoqCr29vVLIGgwGhz4BAQGYM2eOtA4iIm8x7mn7nc6fP4/6+nq8/PLLAIC+vj4AgEqlsusXEhIiLQ8LC4PRaERwcLDD+lQqFYxGo6zPDg1VwM/PV3at4eGOnzfZgpWBTluXO+p3JtbvXqx/ck0oPLu6urB161YsX74cq1evtlvm4+Nj99pmszm039nnm/3k6OkZkN03PDwY16+bZPd3FpN50Gnrckf9zuKu7e8srN+9PKX+ewW47NP23t5eFBUVYcaMGXj11Vel9ttHmHcePd5+ffuIVKVSwWRy3Bgmk8nhqJWIyNPJCk+LxYINGzZgZGQElZWVUCgU0rLb45h3jlsaDAZMnz5dmnBSq9UOfYaHh9HZ2ekwFkpE5OnGDc9bt26hpKQEV65cweHDh/HAAw/YLZ89ezYiIyNRV1cntY2OjqKurg5paWlS27Jly/Dxxx+ju7tbanvvvfcwPDxs14+IyBuMO+b54osvorGxETt27EBfXx/a2tqkZXFxcQgICIBWq8Xzzz+PWbNmISkpCX/5y1/Q0dGB1157TeqblZWFiooKaLValJSUwGQy4ZVXXsGqVat4jScReZ1xw/PcuXMAgJdeeslhWX19PSIiIrBq1SoMDAzg8OHD0Ov1+N73vodDhw5JdxcBgL+/P37729+ivLwczz77LAICArBixQps27bNiV+HiMg1fGwTme52s4nMvrlrtu79tu7xO8kQrAwcd+b+iQWznPJZk8FTZktFsX738pT6nTLbTkREX2N4EhEJYHgSEQmY8O2Z5Dnkjq968tgokbfikScRkQCGJxGRAIYnEZEAjnnK5KzrN4loauCRJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAhicRkQCGJxGRAIYnEZEAWeHZ0dGBnTt34kc/+hFiY2ORn5/v0Mdms6GiogLp6elISEhAXl4eLly44NCvvb0dBQUFSExMxNKlS/H666/DarV++29CRORCssLz8uXLaGxsRGRkJCIjI8fsU1lZCb1ej6KiIlRUVEChUECj0eD69etSn76+Pmg0Gvj4+ECv12Pjxo2oqqrC/v37nfJliIhcxU9Op8zMTDz55JMAgM2bN6Onp8du+dDQECorK1FcXIx169YBABYsWIDMzEwcPXoUW7ZsAQDU1NRgaGgIOp0OSqUSS5Ysgdlshk6nQ1FREZRKpTO/GxHRpJF15Dlt2r27tba2wmw2Izs7W2pTKBTIyMhAc3Oz1NbU1ISlS5faheTKlSsxODiIlpaWidZOROQ2TpkwMhgM8PX1dTilj4qKgsFgsOunVqvt+sycORNBQUF2/YiIPJ1TwtNoNEKhUMDX19euPSQkBBaLBcPDw1K/4OBgh/erVCoYjUZnlEJE5BKyxjzl8PHxcWiz2WwOy+7Wb6z2O4WGKuDn5ztuv9vCwx2DWlSwMtBp63L1ZzpzO3jD5zoL63cvT6/fKeGpUqnQ398Pq9Vqd/RpNBoRFBQEf39/qZ/JZHJ4v9lsHvOI9E49PQOyawoPD8b1646fJcpkHnTauuQIVgY67TOduR3kcvb2dzXW716eUv+9Atwpp+1qtRpWqxUdHR127XeOcarVaoexzatXr2JgYMBhLJSIyJM5JTyTkpKgVCpRV1cntVksFjQ0NCAtLU1qW7ZsGc6ePQuz2Sy11dbWIjAwECkpKc4ohYjIJWSdtlssFjQ2NgIArl27BrPZLAVleno6goKCUFxcDL1ej5CQEKjValRVVWF0dNTubqTc3FxUV1dDq9WiqKgInZ2d0Ol00Gg0vMaTiLyKj+32rM49dHV1Yfny5WMuq6+vR0REhHR75vHjx9Hb24v58+ejrKwMcXFxdv3b29tRXl6OtrY2qFQqPP3009BqtQ4z9WOZyBiIs8dM3m/rdtq65HDmmKccTyyY5dT1ecqYlSjW716eUv+9xjxlhaenYHhOHoanPdbvXp5S/6RPGBER/X/D8CQiEsDwJCISwPAkIhLA8CQiEsDwJCIS4LQHg3gzV1+GRETej0eeREQCGJ5ERAIYnkREAhieREQCGJ5ERAIYnkREAhieREQCGJ5ERAJ4kTwBkHejgLOf+UnkzXjkSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCRgys+283FzRDQZeORJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkYMrfYUTOw2d+En3N5eHZ3t6OX/7yl2hra0NwcDDWrFmDTZs2wdfX19Wl0CS4HbDBykCYzIPfal0MYvJkLg3Pvr4+aDQaREdHQ6/X4/PPP8eePXswOjqKLVu2uLIUmiLkPruAQUzO5tLwrKmpwdDQEHQ6HZRKJZYsWQKz2QydToeioiIolUpXlkMejg91IU/m0vBsamrC0qVL7UJy5cqV2Lt3L1paWpCZmenKcuj/kYkG8d2GHZx5BMsxZO/m0vA0GAxYtGiRXdvMmTMRFBQEg8HA8CSP5+qjYdHPuzP8GcLO59LwNBqNCA4OdmhXqVQwGo2uLIWIBLniP5Bvhr+c4HfHUbzLZ9t9fHwc2mw225jtdwoPdwze8fqveeqRCb2HiO7NE/cpd9Tk0ovkVSoVTCaTQ7vZbB7ziJSIyFO5NDzVajUMBoNd29WrVzEwMAC1Wu3KUoiIvhWXhueyZctw9uxZmM1mqa22thaBgYFISUlxZSlERN+KS8MzNzcXAQEB0Gq1+Mc//oHf//730Ol00Gg0vMaTiLyKj81ms7nyA9vb21FeXo62tjaoVCo8/fTT0Gq1vD2TiLyKy8OTiGgqmFJPVfKmh4688847eOutt/DJJ5/AbDZj3rx5eOaZZ7Bq1SqpT35+PlpaWhze+9FHH+G+++5zZbkOTp48ie3btzu0v/DCC1i7di2Ary5BO3ToEI4fP46enh7Ex8ejrKwMsbGxri7Xwd22LfDVbcQLFy5EZmYmurvtrx988MEHce7cOVeUaKejowNvvPEG2tracPnyZSQnJ6O6utquj9zt7Y79ZLz6v/jiC1RVVeHcuXPo7OyESqXCokWLUFpaioceekjq98EHH2D9+vUO6y8qKsJzzz03afWPZcqEp7c9dOTIkSOIiIjA9u3bERoaiqamJmzduhU9PT3Iz8+X+qWmpqK0tNTuvQEBAa4u967efPNNBAYGSq9nz54t/VxZWQm9Xo9t27ZBrVajqqoKGo0Gp06dQnh4uDvKlezatctu4hIA9u/fj08//RTx8fFS26pVq+x+H/7+/i6r8ZsuX76MxsZGJCYmYmRkZMw+cra3u/aT8er/5JNP8Pe//x1r1qxBQkICbty4gQMHDmDt2rV4++23cf/999v137t3r92/tW8GrMvYpoiKigpbcnKyzWQySW2VlZW2hIQEuzZPcePGDYe20tJSW0ZGhvR63bp1Nq1W68qyZDtx4oQtJibGZjabx1w+ODhoS0pKsh04cEBq6+/vt6Wmptr27dvnqjJlGxoasj3++OO2nTt3Sm0ZGRm23bt3u7Gqr1mtVulnrVZrW7dund1yudvbXfvJePX39fXZRkZG7NoMBoMtJibGdvLkSantn//8py0mJsZ26dKlSatVrinzJPm7PXRkcHDwrqdn7hQWFubQFhsbi5s3b7qhGudrbW2F2WxGdna21KZQKJCRkYHm5mY3Vja25uZm9PX12Q2beJJp0+69q8rd3u7aT8arX6VSwc/P/kR43rx5CAoKwo0bNyatrm9jyoSnwWBwuND+mw8d8QYffvghoqKi7NrOnj2LxMREJCYmorCwEBcvXnRTdWN76qmnEBcXh6ysLNTU1EjtBoMBvr6+iIyMtOsfFRXlkb+P2tpaPPTQQ0hOTrZrP3HiBObPn4/HHnsMmzdvdhgD9RRyt7c37ScXL16ExWJBdHS0w7KCggLExsYiMzMTer0eVqvV5fVNmTFPb3/oyPnz51FfX4+XX35Zanv88ceRk5ODuXPnoru7GxUVFcjLy8Nbb72FiIgIN1YLhIeHo6SkBAkJCbBarTh9+jR27dqFwcFBaDQaGI1GKBQKh0mIkJAQWCwWDA8Pe8zYrcViwXvvvYef/OQnds9YyMzMxIIFC/Dd734Xn332GXQ6HfLy8vD222973O3Ecre3t+wno6OjeOmllxAZGYklS5ZI7cHBwSguLkZycjL8/f3R0NCAAwcO4ObNmygrK3NpjVMmPIFv99ARd+rq6sLWrVuxfPlyrF69WmrfvHmz9HNycjIWL16M7OxsvPnmm9ixY4c7SpWkpaUhLS1Nep2eno7h4WEcPHhQmg292+/jbsvcpaGhAQMDA1i5cqVd+zd3xuTkZCxcuBA5OTk4ceIENBqNi6scn9zt7Q37yWuvvYa2tjYcPXrUbpIuLi4OcXFx0uvFixcjICAAR44cwc9+9rMxh8Mmy5Q5bffWh4709vaiqKgIM2bMwKuvvnrPvuHh4UhKSsKnn37qouomJisrC729veju7oZKpUJ/f7/D6ZTRaERQUJDbZq3Hcvr0acydO9duln0sMTExmDdvnkduf7nb2xv2k2PHjuGNN97Anj17kJiYOG7/rKws3Lp1C5cuXXJBdV+bMuHpjQ8dsVgs2LBhA0ZGRlBZWQmFQiHrfZ50hHA3arUaVqsVHR0ddu1jjbm5k8lkQlNTk8NR57144vaXu709fT9599138atf/QrPP/88VqxYMaH3uvr3MmXC09seOnLr1i2UlJTgypUrOHz4MB544IFx3/Pll1+itbUVjz76qAsqnLgzZ84gNDQUs2bNQlJSEpRKJerq6qTlFosFDQ0Ndqf77va3v/0Nw8PDsmbZ//3vf+M///mPR25/udvbk/eTDz74AM899xzy8vJQWFgo+31nzpyBn58fHn744UmsztGUGfPMzc1FdXU1tFotioqK0NnZ6dEPHXnxxRfR2NiIHTt2oK+vD21tbdKyuLg4GAwG7Nu3Dz/84Q8xc+ZMXL16FYcOHcK0adNQUFDgxsq/otVqER8fj4cffhijo6Oora1FbW0tysrKMG3aNNx3330oLi6GXq9HSEiIdNH26Oio3UXn7nb69Gk88sgjDlc5vP/++/jrX/+KJ554At/5zndgMBhw8OBBzJgxw25c2lUsFgsaGxsBANeuXYPZbJaCMj09HUFBQbK2t7v2k/Hq/+9//4uNGzdCrVZjxYoVdvtDWFgY5syZA+CrmxvCwsIQHx8Pf39/NDY24tixYygoKEBoaOik1T+WKXVvuzc9dGSsW/9uq6+vh7+/P8rKynDhwgX09vbi/vvvR0pKCp599lmHHd0d9u3bh3fffRf/+9//YLPZEB0djfXr1yMnJ0fqY7PZUFFRgePHj6O3txfz589HWVmZ3YC/O928eRNpaWkoKSlBcXGx3bKLFy/ilVdewaVLl2AymTB9+nSkpaVhy5YtbrmbpaurC8uXLx9zWX19PSIiImRvb3fsJ+PV39LSMubtvgDw4x//GLt37wYA/O53v8Mf//hHdHV1YWRkBHPnzsWaNWuwfv36ca8ldbYpFZ5ERK4yZcY8iYhcieFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJOD/ACxkaTz39O5PAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "task_name = 'clintox'\n",
    "tasks = [\n",
    "    'FDA_APPROVED','CT_TOX'\n",
    "]\n",
    "raw_filename = \"../data/clintox.csv\"\n",
    "feature_filename = raw_filename.replace('.csv','.pickle')\n",
    "filename = raw_filename.replace('.csv','')\n",
    "prefix_filename = raw_filename.split('/')[-1].replace('.csv','')\n",
    "smiles_tasks_df = pd.read_csv(raw_filename)\n",
    "smilesList = smiles_tasks_df.smiles.values\n",
    "print(\"number of all smiles: \",len(smilesList))\n",
    "atom_num_dist = []\n",
    "remained_smiles = []\n",
    "canonical_smiles_list = []\n",
    "for smiles in smilesList:\n",
    "    try:        \n",
    "        mol = Chem.MolFromSmiles(smiles)\n",
    "        atom_num_dist.append(len(mol.GetAtoms()))\n",
    "        remained_smiles.append(smiles)\n",
    "        canonical_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True))\n",
    "    except:\n",
    "        print(\"not successfully processed smiles: \", smiles)\n",
    "        pass\n",
    "print(\"number of successfully processed smiles: \", len(remained_smiles))\n",
    "smiles_tasks_df = smiles_tasks_df[smiles_tasks_df[\"smiles\"].isin(remained_smiles)]\n",
    "# print(smiles_tasks_df)\n",
    "smiles_tasks_df['cano_smiles'] =canonical_smiles_list\n",
    "\n",
    "plt.figure(figsize=(5, 3))\n",
    "sns.set(font_scale=1.5)\n",
    "ax = sns.distplot(atom_num_dist, bins=28, kde=False)\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"atom_num_dist_\"+prefix_filename+\".png\",dpi=200)\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# print(len([i for i in atom_num_dist if i<51]),len([i for i in atom_num_dist if i>50]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_seed = 888\n",
    "start_time = str(time.ctime()).replace(':','-').replace(' ','_')\n",
    "start = time.time()\n",
    "\n",
    "batch_size = 100\n",
    "epochs = 800\n",
    "p_dropout = 0.5\n",
    "fingerprint_dim = 200\n",
    "\n",
    "radius = 3\n",
    "T = 3\n",
    "weight_decay = 3 # also known as l2_regularization_lambda\n",
    "learning_rate = 3.5\n",
    "per_task_output_units_num = 2 # for classification model with 2 classes\n",
    "output_units_num = len(tasks) * per_task_output_units_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>FDA_APPROVED</th>\n",
       "      <th>CT_TOX</th>\n",
       "      <th>cano_smiles</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>[Se]</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>[Se]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>C(#N)[Fe-2](C#N)(C#N)(C#N)(C#N)N=O</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>N#C[Fe-2](C#N)(C#N)(C#N)(C#N)N=O</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                smiles  FDA_APPROVED  CT_TOX  \\\n",
       "12                                [Se]             0       1   \n",
       "20  C(#N)[Fe-2](C#N)(C#N)(C#N)(C#N)N=O             1       0   \n",
       "\n",
       "                         cano_smiles  \n",
       "12                              [Se]  \n",
       "20  N#C[Fe-2](C#N)(C#N)(C#N)(C#N)N=O  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smilesList = [smiles for smiles in canonical_smiles_list if len(Chem.MolFromSmiles(smiles).GetAtoms())<151]\n",
    "\n",
    "if os.path.isfile(feature_filename):\n",
    "    feature_dicts = pickle.load(open(feature_filename, \"rb\" ))\n",
    "else:\n",
    "    feature_dicts = save_smiles_dicts(smilesList,filename)\n",
    "# feature_dicts = get_smiles_dicts(smilesList)\n",
    "\n",
    "remained_df = smiles_tasks_df[smiles_tasks_df[\"cano_smiles\"].isin(feature_dicts['smiles_to_atom_mask'].keys())]\n",
    "uncovered_df = smiles_tasks_df.drop(remained_df.index)\n",
    "uncovered_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = []\n",
    "for i,task in enumerate(tasks):    \n",
    "    negative_df = remained_df[remained_df[task] == 0][[\"smiles\",task]]\n",
    "    positive_df = remained_df[remained_df[task] == 1][[\"smiles\",task]]\n",
    "    weights.append([(positive_df.shape[0]+negative_df.shape[0])/negative_df.shape[0],\\\n",
    "                    (positive_df.shape[0]+negative_df.shape[0])/positive_df.shape[0]])\n",
    "\n",
    "test_df = remained_df.sample(frac=1/10, random_state=3) # test set\n",
    "training_data = remained_df.drop(test_df.index) # training data\n",
    "\n",
    "# training data is further divided into validation set and train set\n",
    "valid_df = training_data.sample(frac=1/9, random_state=3) # validation set\n",
    "train_df = training_data.drop(valid_df.index) # train set\n",
    "train_df = train_df.reset_index(drop=True)\n",
    "valid_df = valid_df.reset_index(drop=True)\n",
    "test_df = test_df.reset_index(drop=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1146008\n",
      "atom_fc.weight torch.Size([200, 39])\n",
      "atom_fc.bias torch.Size([200])\n",
      "neighbor_fc.weight torch.Size([200, 49])\n",
      "neighbor_fc.bias torch.Size([200])\n",
      "GRUCell.0.weight_ih torch.Size([600, 200])\n",
      "GRUCell.0.weight_hh torch.Size([600, 200])\n",
      "GRUCell.0.bias_ih torch.Size([600])\n",
      "GRUCell.0.bias_hh torch.Size([600])\n",
      "GRUCell.1.weight_ih torch.Size([600, 200])\n",
      "GRUCell.1.weight_hh torch.Size([600, 200])\n",
      "GRUCell.1.bias_ih torch.Size([600])\n",
      "GRUCell.1.bias_hh torch.Size([600])\n",
      "GRUCell.2.weight_ih torch.Size([600, 200])\n",
      "GRUCell.2.weight_hh torch.Size([600, 200])\n",
      "GRUCell.2.bias_ih torch.Size([600])\n",
      "GRUCell.2.bias_hh torch.Size([600])\n",
      "align.0.weight torch.Size([1, 400])\n",
      "align.0.bias torch.Size([1])\n",
      "align.1.weight torch.Size([1, 400])\n",
      "align.1.bias torch.Size([1])\n",
      "align.2.weight torch.Size([1, 400])\n",
      "align.2.bias torch.Size([1])\n",
      "attend.0.weight torch.Size([200, 200])\n",
      "attend.0.bias torch.Size([200])\n",
      "attend.1.weight torch.Size([200, 200])\n",
      "attend.1.bias torch.Size([200])\n",
      "attend.2.weight torch.Size([200, 200])\n",
      "attend.2.bias torch.Size([200])\n",
      "mol_GRUCell.weight_ih torch.Size([600, 200])\n",
      "mol_GRUCell.weight_hh torch.Size([600, 200])\n",
      "mol_GRUCell.bias_ih torch.Size([600])\n",
      "mol_GRUCell.bias_hh torch.Size([600])\n",
      "mol_align.weight torch.Size([1, 400])\n",
      "mol_align.bias torch.Size([1])\n",
      "mol_attend.weight torch.Size([200, 200])\n",
      "mol_attend.bias torch.Size([200])\n",
      "output.weight torch.Size([4, 200])\n",
      "output.bias torch.Size([4])\n"
     ]
    }
   ],
   "source": [
    "x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([canonical_smiles_list[0]],feature_dicts)\n",
    "num_atom_features = x_atom.shape[-1]\n",
    "num_bond_features = x_bonds.shape[-1]\n",
    "\n",
    "loss_function = [nn.CrossEntropyLoss(torch.Tensor(weight),reduction='mean') for weight in weights]\n",
    "model = Fingerprint(radius, T, num_atom_features,num_bond_features,\n",
    "            fingerprint_dim, output_units_num, p_dropout)\n",
    "model.cuda()\n",
    "# tensorboard = SummaryWriter(log_dir=\"runs/\"+start_time+\"_\"+prefix_filename+\"_\"+str(fingerprint_dim)+\"_\"+str(p_dropout))\n",
    "\n",
    "# optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)\n",
    "optimizer = optim.Adam(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)\n",
    "model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
    "params = sum([np.prod(p.size()) for p in model_parameters])\n",
    "print(params)\n",
    "for name, param in model.named_parameters():\n",
    "    if param.requires_grad:\n",
    "        print(name, param.data.shape)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, dataset, optimizer, loss_function):\n",
    "    model.train()\n",
    "    np.random.seed(epoch)\n",
    "    valList = np.arange(0,dataset.shape[0])\n",
    "    #shuffle them\n",
    "    np.random.shuffle(valList)\n",
    "    batch_list = []\n",
    "    for i in range(0, dataset.shape[0], batch_size):\n",
    "        batch = valList[i:i+batch_size]\n",
    "        batch_list.append(batch)   \n",
    "    for counter, train_batch in enumerate(batch_list):\n",
    "        batch_df = dataset.loc[train_batch,:]\n",
    "        smiles_list = batch_df.cano_smiles.values\n",
    "        \n",
    "        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)\n",
    "        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask))\n",
    "#         print(torch.Tensor(x_atom).size(),torch.Tensor(x_bonds).size(),torch.cuda.LongTensor(x_atom_index).size(),torch.cuda.LongTensor(x_bond_index).size(),torch.Tensor(x_mask).size())\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss = 0.0\n",
    "        for i,task in enumerate(tasks):\n",
    "            y_pred = mol_prediction[:, i * per_task_output_units_num:(i + 1) *\n",
    "                                    per_task_output_units_num]\n",
    "            y_val = batch_df[task].values\n",
    "\n",
    "            validInds = np.where((y_val==0) | (y_val==1))[0]\n",
    "#             validInds = np.where(y_val != -1)[0]\n",
    "            if len(validInds) == 0:\n",
    "                continue\n",
    "            y_val_adjust = np.array([y_val[v] for v in validInds]).astype(float)\n",
    "            validInds = torch.cuda.LongTensor(validInds).squeeze()\n",
    "            y_pred_adjust = torch.index_select(y_pred, 0, validInds)\n",
    "\n",
    "            loss += loss_function[i](\n",
    "                y_pred_adjust,\n",
    "                torch.cuda.LongTensor(y_val_adjust))\n",
    "        # Step 5. Do the backward pass and update the gradient\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "def eval(model, dataset):\n",
    "    model.eval()\n",
    "    y_val_list = {}\n",
    "    y_pred_list = {}\n",
    "    losses_list = []\n",
    "    valList = np.arange(0,dataset.shape[0])\n",
    "    batch_list = []\n",
    "    for i in range(0, dataset.shape[0], batch_size):\n",
    "        batch = valList[i:i+batch_size]\n",
    "        batch_list.append(batch)   \n",
    "    for counter, eval_batch in enumerate(batch_list):\n",
    "        batch_df = dataset.loc[eval_batch,:]\n",
    "        smiles_list = batch_df.cano_smiles.values\n",
    "        \n",
    "        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)\n",
    "        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask))\n",
    "        atom_pred = atoms_prediction.data[:,:,1].unsqueeze(2).cpu().numpy()\n",
    "        for i,task in enumerate(tasks):\n",
    "            y_pred = mol_prediction[:, i * per_task_output_units_num:(i + 1) *\n",
    "                                    per_task_output_units_num]\n",
    "            y_val = batch_df[task].values\n",
    "\n",
    "            validInds = np.where((y_val==0) | (y_val==1))[0]\n",
    "#             validInds = np.where((y_val=='0') | (y_val=='1'))[0]\n",
    "#             print(validInds)\n",
    "            if len(validInds) == 0:\n",
    "                continue\n",
    "            y_val_adjust = np.array([y_val[v] for v in validInds]).astype(float)\n",
    "            validInds = torch.cuda.LongTensor(validInds).squeeze()\n",
    "            y_pred_adjust = torch.index_select(y_pred, 0, validInds)\n",
    "#             print(validInds)\n",
    "            loss = loss_function[i](\n",
    "                y_pred_adjust,\n",
    "                torch.cuda.LongTensor(y_val_adjust))\n",
    "#             print(y_pred_adjust)\n",
    "            y_pred_adjust = F.softmax(y_pred_adjust,dim=-1).data.cpu().numpy()[:,1]\n",
    "            losses_list.append(loss.cpu().detach().numpy())\n",
    "            try:\n",
    "                y_val_list[i].extend(y_val_adjust)\n",
    "                y_pred_list[i].extend(y_pred_adjust)\n",
    "            except:\n",
    "                y_val_list[i] = []\n",
    "                y_pred_list[i] = []\n",
    "                y_val_list[i].extend(y_val_adjust)\n",
    "                y_pred_list[i].extend(y_pred_adjust)\n",
    "                \n",
    "    eval_roc = [roc_auc_score(y_val_list[i], y_pred_list[i]) for i in range(len(tasks))]\n",
    "#     eval_prc = [auc(precision_recall_curve(y_val_list[i], y_pred_list[i])[1],precision_recall_curve(y_val_list[i], y_pred_list[i])[0]) for i in range(len(tasks))]\n",
    "#     eval_precision = [precision_score(y_val_list[i],\n",
    "#                                      (np.array(y_pred_list[i]) > 0.5).astype(int)) for i in range(len(tasks))]\n",
    "#     eval_recall = [recall_score(y_val_list[i],\n",
    "#                                (np.array(y_pred_list[i]) > 0.5).astype(int)) for i in range(len(tasks))]\n",
    "    eval_loss = np.array(losses_list).mean()\n",
    "    \n",
    "    return eval_roc, eval_loss #eval_prc, eval_precision, eval_recall, \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH:\t0\n",
      "train_roc:[0.3443868823615659, 0.3605198776758409]\n",
      "valid_roc:[0.3638731060606061, 0.35765603951504266]\n",
      "train_roc_mean:0.3524533800187034\n",
      "valid_roc_mean:0.36076457278782437\n",
      "\n",
      "EPOCH:\t1\n",
      "train_roc:[0.6697619862176825, 0.6461060142711519]\n",
      "valid_roc:[0.6583806818181818, 0.6459362370902559]\n",
      "train_roc_mean:0.6579340002444172\n",
      "valid_roc_mean:0.6521584594542189\n",
      "\n",
      "EPOCH:\t2\n",
      "train_roc:[0.6748937002101559, 0.655902140672783]\n",
      "valid_roc:[0.657907196969697, 0.7029636281993714]\n",
      "train_roc_mean:0.6653979204414695\n",
      "valid_roc_mean:0.6804354125845342\n",
      "\n",
      "EPOCH:\t3\n",
      "train_roc:[0.7228996627730805, 0.6761264016309889]\n",
      "valid_roc:[0.6706912878787878, 0.7168837000449034]\n",
      "train_roc_mean:0.6995130322020346\n",
      "valid_roc_mean:0.6937874939618456\n",
      "\n",
      "EPOCH:\t4\n",
      "train_roc:[0.7082131860612875, 0.6798267074413864]\n",
      "valid_roc:[0.6758996212121212, 0.6850022451728783]\n",
      "train_roc_mean:0.6940199467513369\n",
      "valid_roc_mean:0.6804509331924997\n",
      "\n",
      "EPOCH:\t5\n",
      "train_roc:[0.6830677874981672, 0.6382059123343526]\n",
      "valid_roc:[0.6550662878787878, 0.6885945217781769]\n",
      "train_roc_mean:0.6606368499162599\n",
      "valid_roc_mean:0.6718304048284824\n",
      "\n",
      "EPOCH:\t6\n",
      "train_roc:[0.7260642197351059, 0.6824260958205912]\n",
      "valid_roc:[0.6858428030303032, 0.7083520431073193]\n",
      "train_roc_mean:0.7042451577778486\n",
      "valid_roc_mean:0.6970974230688112\n",
      "\n",
      "EPOCH:\t7\n",
      "train_roc:[0.7218977567078833, 0.6741080530071356]\n",
      "valid_roc:[0.6754261363636364, 0.6971261787157611]\n",
      "train_roc_mean:0.6980029048575094\n",
      "valid_roc_mean:0.6862761575396987\n",
      "\n",
      "EPOCH:\t8\n",
      "train_roc:[0.7349836273886907, 0.6894087665647298]\n",
      "valid_roc:[0.6749526515151515, 0.7047597665020207]\n",
      "train_roc_mean:0.7121961969767103\n",
      "valid_roc_mean:0.689856209008586\n",
      "\n",
      "EPOCH:\t9\n",
      "train_roc:[0.7377938517179023, 0.6943323139653416]\n",
      "valid_roc:[0.6934185606060607, 0.7128423888639425]\n",
      "train_roc_mean:0.7160630828416219\n",
      "valid_roc_mean:0.7031304747350016\n",
      "\n",
      "EPOCH:\t10\n",
      "train_roc:[0.7422291188113973, 0.6997145769622833]\n",
      "valid_roc:[0.6801609848484848, 0.7110462505612931]\n",
      "train_roc_mean:0.7209718478868403\n",
      "valid_roc_mean:0.695603617704889\n",
      "\n",
      "EPOCH:\t11\n",
      "train_roc:[0.7458946288060212, 0.7017125382262996]\n",
      "valid_roc:[0.6867897727272726, 0.7191288729232151]\n",
      "train_roc_mean:0.7238035835161605\n",
      "valid_roc_mean:0.7029593228252439\n",
      "\n",
      "EPOCH:\t12\n",
      "train_roc:[0.7539098773275988, 0.7121202854230377]\n",
      "valid_roc:[0.6886837121212122, 0.7114952851369556]\n",
      "train_roc_mean:0.7330150813753182\n",
      "valid_roc_mean:0.7000894986290839\n",
      "\n",
      "EPOCH:\t13\n",
      "train_roc:[0.7740213088314354, 0.7281957186544342]\n",
      "valid_roc:[0.7123579545454546, 0.7312528064660979]\n",
      "train_roc_mean:0.7511085137429347\n",
      "valid_roc_mean:0.7218053805057763\n",
      "\n",
      "EPOCH:\t14\n",
      "train_roc:[0.7703557988368115, 0.7275229357798165]\n",
      "valid_roc:[0.7270359848484848, 0.7379883251010327]\n",
      "train_roc_mean:0.748939367308314\n",
      "valid_roc_mean:0.7325121549747587\n",
      "\n",
      "EPOCH:\t15\n",
      "train_roc:[0.7922266751380675, 0.752803261977574]\n",
      "valid_roc:[0.7265625, 0.7451728783116299]\n",
      "train_roc_mean:0.7725149685578208\n",
      "valid_roc_mean:0.7358676891558149\n",
      "\n",
      "EPOCH:\t16\n",
      "train_roc:[0.7870583060456479, 0.74506625891947]\n",
      "valid_roc:[0.7530776515151515, 0.7433767400089807]\n",
      "train_roc_mean:0.7660622824825589\n",
      "valid_roc_mean:0.7482271957620661\n",
      "\n",
      "EPOCH:\t17\n",
      "train_roc:[0.8130834270074777, 0.7744546381243628]\n",
      "valid_roc:[0.7739109848484849, 0.7689717108217333]\n",
      "train_roc_mean:0.7937690325659202\n",
      "valid_roc_mean:0.7714413478351091\n",
      "\n",
      "EPOCH:\t18\n",
      "train_roc:[0.8248130589902742, 0.7951274209989806]\n",
      "valid_roc:[0.7331912878787878, 0.7456219128872923]\n",
      "train_roc_mean:0.8099702399946274\n",
      "valid_roc_mean:0.7394066003830401\n",
      "\n",
      "EPOCH:\t19\n",
      "train_roc:[0.8164679145691804, 0.7966258919469928]\n",
      "valid_roc:[0.704782196969697, 0.7240682532555007]\n",
      "train_roc_mean:0.8065469032580866\n",
      "valid_roc_mean:0.7144252251125989\n",
      "\n",
      "EPOCH:\t20\n",
      "train_roc:[0.8353452910414936, 0.8107237512742099]\n",
      "valid_roc:[0.7961647727272727, 0.7972608890884598]\n",
      "train_roc_mean:0.8230345211578518\n",
      "valid_roc_mean:0.7967128309078663\n",
      "\n",
      "EPOCH:\t21\n",
      "train_roc:[0.8492986657543621, 0.8250560652395516]\n",
      "valid_roc:[0.7691761363636364, 0.7932195779074989]\n",
      "train_roc_mean:0.8371773654969568\n",
      "valid_roc_mean:0.7811978571355676\n",
      "\n",
      "EPOCH:\t22\n",
      "train_roc:[0.858743463173843, 0.8348623853211009]\n",
      "valid_roc:[0.7956912878787878, 0.8071396497530311]\n",
      "train_roc_mean:0.8468029242474719\n",
      "valid_roc_mean:0.8014154688159094\n",
      "\n",
      "EPOCH:\t23\n",
      "train_roc:[0.8656590586970334, 0.8429153924566768]\n",
      "valid_roc:[0.8070549242424243, 0.8201616524472385]\n",
      "train_roc_mean:0.8542872255768551\n",
      "valid_roc_mean:0.8136082883448315\n",
      "\n",
      "EPOCH:\t24\n",
      "train_roc:[0.8742241337178047, 0.8508460754332314]\n",
      "valid_roc:[0.8075284090909092, 0.8246519982038617]\n",
      "train_roc_mean:0.8625351045755181\n",
      "valid_roc_mean:0.8160902036473854\n",
      "\n",
      "EPOCH:\t25\n",
      "train_roc:[0.8753726601827867, 0.8526401630988787]\n",
      "valid_roc:[0.8151041666666666, 0.8309384822631343]\n",
      "train_roc_mean:0.8640064116408327\n",
      "valid_roc_mean:0.8230213244649005\n",
      "\n",
      "EPOCH:\t26\n",
      "train_roc:[0.8791481354772495, 0.8549133537206931]\n",
      "valid_roc:[0.8293087121212122, 0.845756623259991]\n",
      "train_roc_mean:0.8670307445989713\n",
      "valid_roc_mean:0.8375326676906016\n",
      "\n",
      "EPOCH:\t27\n",
      "train_roc:[0.8831679781046871, 0.8606523955147809]\n",
      "valid_roc:[0.8312026515151516, 0.8533902110462506]\n",
      "train_roc_mean:0.871910186809734\n",
      "valid_roc_mean:0.8422964312807011\n",
      "\n",
      "EPOCH:\t28\n",
      "train_roc:[0.8857093983676261, 0.8631294597349642]\n",
      "valid_roc:[0.8397253787878787, 0.8533902110462506]\n",
      "train_roc_mean:0.8744194290512952\n",
      "valid_roc_mean:0.8465577949170646\n",
      "\n",
      "EPOCH:\t29\n",
      "train_roc:[0.8811519476076438, 0.8545056065239551]\n",
      "valid_roc:[0.8392518939393939, 0.8569824876515492]\n",
      "train_roc_mean:0.8678287770657995\n",
      "valid_roc_mean:0.8481171907954715\n",
      "\n",
      "EPOCH:\t30\n",
      "train_roc:[0.8781340110454035, 0.8534454638124364]\n",
      "valid_roc:[0.829782196969697, 0.852492141894926]\n",
      "train_roc_mean:0.8657897374289198\n",
      "valid_roc_mean:0.8411371694323115\n",
      "\n",
      "EPOCH:\t31\n",
      "train_roc:[0.8911221347930207, 0.8714169215086647]\n",
      "valid_roc:[0.8269412878787878, 0.8520431073192636]\n",
      "train_roc_mean:0.8812695281508427\n",
      "valid_roc_mean:0.8394921975990257\n",
      "\n",
      "EPOCH:\t32\n",
      "train_roc:[0.8908655490933971, 0.867217125382263]\n",
      "valid_roc:[0.8491950757575758, 0.8578805568028739]\n",
      "train_roc_mean:0.8790413372378301\n",
      "valid_roc_mean:0.8535378162802248\n",
      "\n",
      "EPOCH:\t33\n",
      "train_roc:[0.8982820976491862, 0.8763710499490316]\n",
      "valid_roc:[0.833096590909091, 0.8551863493488999]\n",
      "train_roc_mean:0.887326573799109\n",
      "valid_roc_mean:0.8441414701289955\n",
      "\n",
      "EPOCH:\t34\n",
      "train_roc:[0.8996994281804408, 0.8806320081549439]\n",
      "valid_roc:[0.8288352272727273, 0.8538392456219128]\n",
      "train_roc_mean:0.8901657181676923\n",
      "valid_roc_mean:0.8413372364473201\n",
      "\n",
      "EPOCH:\t35\n",
      "train_roc:[0.9005913689457994, 0.8817940876656472]\n",
      "valid_roc:[0.8477746212121212, 0.8740458015267175]\n",
      "train_roc_mean:0.8911927283057233\n",
      "valid_roc_mean:0.8609102113694194\n",
      "\n",
      "EPOCH:\t36\n",
      "train_roc:[0.9056253360050828, 0.8873598369011213]\n",
      "valid_roc:[0.8430397727272727, 0.8709025594970813]\n",
      "train_roc_mean:0.896492586453102\n",
      "valid_roc_mean:0.856971166112177\n",
      "\n",
      "EPOCH:\t37\n",
      "train_roc:[0.8955329651532182, 0.8788175331294599]\n",
      "valid_roc:[0.8747632575757576, 0.8848226313426134]\n",
      "train_roc_mean:0.887175249141339\n",
      "valid_roc_mean:0.8797929444591854\n",
      "\n",
      "EPOCH:\t38\n",
      "train_roc:[0.9070060114363911, 0.8870948012232416]\n",
      "valid_roc:[0.8709753787878789, 0.8929052537045353]\n",
      "train_roc_mean:0.8970504063298164\n",
      "valid_roc_mean:0.8819403162462072\n",
      "\n",
      "EPOCH:\t39\n",
      "train_roc:[0.9128097355945457, 0.8948623853211009]\n",
      "valid_roc:[0.8581912878787878, 0.8821284238886394]\n",
      "train_roc_mean:0.9038360604578233\n",
      "valid_roc_mean:0.8701598558837136\n",
      "\n",
      "EPOCH:\t40\n",
      "train_roc:[0.917807047553883, 0.90217125382263]\n",
      "valid_roc:[0.8572443181818181, 0.8812303547373147]\n",
      "train_roc_mean:0.9099891506882565\n",
      "valid_roc_mean:0.8692373364595665\n",
      "\n",
      "EPOCH:\t41\n",
      "train_roc:[0.9226455207467866, 0.9074108053007136]\n",
      "valid_roc:[0.8719223484848485, 0.8929052537045353]\n",
      "train_roc_mean:0.9150281630237501\n",
      "valid_roc_mean:0.8824138010946919\n",
      "\n",
      "EPOCH:\t42\n",
      "train_roc:[0.9257001124089732, 0.91072375127421]\n",
      "valid_roc:[0.8813920454545454, 0.9018859452177818]\n",
      "train_roc_mean:0.9182119318415916\n",
      "valid_roc_mean:0.8916389953361636\n",
      "\n",
      "EPOCH:\t43\n",
      "train_roc:[0.9289257612042422, 0.9127930682976554]\n",
      "valid_roc:[0.8856534090909091, 0.9086214638527167]\n",
      "train_roc_mean:0.9208594147509488\n",
      "valid_roc_mean:0.8971374364718129\n",
      "\n",
      "EPOCH:\t44\n",
      "train_roc:[0.9086310542006745, 0.8982976554536188]\n",
      "valid_roc:[0.8974905303030303, 0.892456219128873]\n",
      "train_roc_mean:0.9034643548271466\n",
      "valid_roc_mean:0.8949733747159516\n",
      "\n",
      "EPOCH:\t45\n",
      "train_roc:[0.9368921362592249, 0.9227013251783893]\n",
      "valid_roc:[0.892282196969697, 0.909070498428379]\n",
      "train_roc_mean:0.9297967307188071\n",
      "valid_roc_mean:0.900676347699038\n",
      "\n",
      "EPOCH:\t46\n",
      "train_roc:[0.930379746835443, 0.9144546381243628]\n",
      "valid_roc:[0.8638731060606061, 0.8825774584643017]\n",
      "train_roc_mean:0.922417192479903\n",
      "valid_roc_mean:0.8732252822624539\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH:\t47\n",
      "train_roc:[0.9281560041053712, 0.9161467889908257]\n",
      "valid_roc:[0.8446969696969697, 0.8686573866187696]\n",
      "train_roc_mean:0.9221513965480985\n",
      "valid_roc_mean:0.8566771781578697\n",
      "\n",
      "EPOCH:\t48\n",
      "train_roc:[0.9369287913591711, 0.9214271151885831]\n",
      "valid_roc:[0.8837594696969697, 0.9054782218230804]\n",
      "train_roc_mean:0.9291779532738771\n",
      "valid_roc_mean:0.894618845760025\n",
      "\n",
      "EPOCH:\t49\n",
      "train_roc:[0.9403132789208738, 0.9290316004077471]\n",
      "valid_roc:[0.8780776515151515, 0.9036820835204312]\n",
      "train_roc_mean:0.9346724396643105\n",
      "valid_roc_mean:0.8908798675177914\n",
      "\n",
      "EPOCH:\t50\n",
      "train_roc:[0.9445774888812863, 0.9327013251783893]\n",
      "valid_roc:[0.892282196969697, 0.9126627750336775]\n",
      "train_roc_mean:0.9386394070298378\n",
      "valid_roc_mean:0.9024724860016873\n",
      "\n",
      "EPOCH:\t51\n",
      "train_roc:[0.9469722887444407, 0.9380530071355759]\n",
      "valid_roc:[0.879498106060606, 0.8991917377638078]\n",
      "train_roc_mean:0.9425126479400083\n",
      "valid_roc_mean:0.8893449219122069\n",
      "\n",
      "EPOCH:\t52\n",
      "train_roc:[0.9522872782366453, 0.9423343527013253]\n",
      "valid_roc:[0.9050662878787878, 0.9158060170633139]\n",
      "train_roc_mean:0.9473108154689853\n",
      "valid_roc_mean:0.9104361524710509\n",
      "\n",
      "EPOCH:\t53\n",
      "train_roc:[0.9549264454327745, 0.9460652395514781]\n",
      "valid_roc:[0.8903882575757576, 0.9059272563987427]\n",
      "train_roc_mean:0.9504958424921264\n",
      "valid_roc_mean:0.8981577569872501\n",
      "\n",
      "EPOCH:\t54\n",
      "train_roc:[0.9512364987048532, 0.9392252803261978]\n",
      "valid_roc:[0.904592803030303, 0.9234396048495734]\n",
      "train_roc_mean:0.9452308895155255\n",
      "valid_roc_mean:0.9140162039399382\n",
      "\n",
      "EPOCH:\t55\n",
      "train_roc:[0.9589218513269147, 0.9519266055045872]\n",
      "valid_roc:[0.8922821969696969, 0.9059272563987426]\n",
      "train_roc_mean:0.9554242284157509\n",
      "valid_roc_mean:0.8991047266842198\n",
      "\n",
      "EPOCH:\t56\n",
      "train_roc:[0.9602414349249793, 0.9519062181447502]\n",
      "valid_roc:[0.907907196969697, 0.9234396048495734]\n",
      "train_roc_mean:0.9560738265348647\n",
      "valid_roc_mean:0.9156734009096352\n",
      "\n",
      "EPOCH:\t57\n",
      "train_roc:[0.9640291285860906, 0.9585626911314984]\n",
      "valid_roc:[0.9050662878787878, 0.9131118096093399]\n",
      "train_roc_mean:0.9612959098587945\n",
      "valid_roc_mean:0.9090890487440639\n",
      "\n",
      "EPOCH:\t58\n",
      "train_roc:[0.9642368408191194, 0.9576350662589195]\n",
      "valid_roc:[0.9102746212121212, 0.925684777727885]\n",
      "train_roc_mean:0.9609359535390194\n",
      "valid_roc_mean:0.9179796994700031\n",
      "\n",
      "EPOCH:\t59\n",
      "train_roc:[0.9647377938517179, 0.9564322120285422]\n",
      "valid_roc:[0.9036458333333334, 0.907723394701392]\n",
      "train_roc_mean:0.96058500294013\n",
      "valid_roc_mean:0.9056846140173627\n",
      "\n",
      "EPOCH:\t60\n",
      "train_roc:[0.9657030448169689, 0.9573700305810398]\n",
      "valid_roc:[0.9112215909090908, 0.9202963628199372]\n",
      "train_roc_mean:0.9615365376990044\n",
      "valid_roc_mean:0.915758976864514\n",
      "\n",
      "EPOCH:\t61\n",
      "train_roc:[0.9644689897854454, 0.9588685015290519]\n",
      "valid_roc:[0.8984375, 0.9149079479119893]\n",
      "train_roc_mean:0.9616687456572486\n",
      "valid_roc_mean:0.9066727239559946\n",
      "\n",
      "EPOCH:\t62\n",
      "train_roc:[0.9680367528468794, 0.9602344546381243]\n",
      "valid_roc:[0.9154829545454546, 0.9261338123035473]\n",
      "train_roc_mean:0.9641356037425018\n",
      "valid_roc_mean:0.9208083834245009\n",
      "\n",
      "EPOCH:\t63\n",
      "train_roc:[0.9675602365475783, 0.9595005096839959]\n",
      "valid_roc:[0.9131155303030303, 0.9261338123035473]\n",
      "train_roc_mean:0.9635303731157872\n",
      "valid_roc_mean:0.9196246713032887\n",
      "\n",
      "EPOCH:\t64\n",
      "train_roc:[0.9684521773129369, 0.9610805300713557]\n",
      "valid_roc:[0.9169034090909091, 0.9279299506061968]\n",
      "train_roc_mean:0.9647663536921463\n",
      "valid_roc_mean:0.9224166798485529\n",
      "\n",
      "EPOCH:\t65\n",
      "train_roc:[0.9694907384780802, 0.9635779816513762]\n",
      "valid_roc:[0.9173768939393939, 0.9319712617871576]\n",
      "train_roc_mean:0.9665343600647283\n",
      "valid_roc_mean:0.9246740778632758\n",
      "\n",
      "EPOCH:\t66\n",
      "train_roc:[0.9658130101168076, 0.9570336391437309]\n",
      "valid_roc:[0.8813920454545454, 0.8875168387965874]\n",
      "train_roc_mean:0.9614233246302692\n",
      "valid_roc_mean:0.8844544421255665\n",
      "\n",
      "EPOCH:\t67\n",
      "train_roc:[0.95789550852842, 0.9464525993883793]\n",
      "valid_roc:[0.9367897727272727, 0.9526268522676246]\n",
      "train_roc_mean:0.9521740539583996\n",
      "valid_roc_mean:0.9447083124974487\n",
      "\n",
      "EPOCH:\t68\n",
      "train_roc:[0.9652265285176677, 0.9566666666666667]\n",
      "valid_roc:[0.9187973484848484, 0.9328693309384823]\n",
      "train_roc_mean:0.9609465975921672\n",
      "valid_roc_mean:0.9258333397116654\n",
      "\n",
      "EPOCH:\t69\n",
      "train_roc:[0.9663506182493524, 0.9589500509683996]\n",
      "valid_roc:[0.896780303030303, 0.9140098787606645]\n",
      "train_roc_mean:0.962650334608876\n",
      "valid_roc_mean:0.9053950908954838\n",
      "\n",
      "EPOCH:\t70\n",
      "train_roc:[0.9680734079468256, 0.9632721712538226]\n",
      "valid_roc:[0.9131155303030303, 0.9234396048495734]\n",
      "train_roc_mean:0.9656727896003241\n",
      "valid_roc_mean:0.9182775675763019\n",
      "\n",
      "EPOCH:\t71\n",
      "train_roc:[0.9702971506768976, 0.9645056065239552]\n",
      "valid_roc:[0.9150094696969697, 0.9225415356982488]\n",
      "train_roc_mean:0.9674013786004264\n",
      "valid_roc_mean:0.9187755026976092\n",
      "\n",
      "EPOCH:\t72\n",
      "train_roc:[0.9707125751429547, 0.9666156982670744]\n",
      "valid_roc:[0.8927556818181819, 0.9041311180960935]\n",
      "train_roc_mean:0.9686641367050146\n",
      "valid_roc_mean:0.8984433999571377\n",
      "\n",
      "EPOCH:\t73\n",
      "train_roc:[0.9721909975074532, 0.9675942915392456]\n",
      "valid_roc:[0.9105113636363636, 0.9225415356982487]\n",
      "train_roc_mean:0.9698926445233493\n",
      "valid_roc_mean:0.9165264496673062\n",
      "\n",
      "EPOCH:\t74\n",
      "train_roc:[0.9733639607057328, 0.9682059123343527]\n",
      "valid_roc:[0.9055397727272727, 0.9167040862146385]\n",
      "train_roc_mean:0.9707849365200427\n",
      "valid_roc_mean:0.9111219294709556\n",
      "\n",
      "EPOCH:\t75\n",
      "train_roc:[0.9725453301402669, 0.9683792048929662]\n",
      "valid_roc:[0.8955965909090909, 0.9108666367310283]\n",
      "train_roc_mean:0.9704622675166166\n",
      "valid_roc_mean:0.9032316138200596\n",
      "\n",
      "EPOCH:\t76\n",
      "train_roc:[0.9742436831044426, 0.9661875637104995]\n",
      "valid_roc:[0.917376893939394, 0.9252357431522227]\n",
      "train_roc_mean:0.970215623407471\n",
      "valid_roc_mean:0.9213063185458084\n",
      "\n",
      "EPOCH:\t77\n",
      "train_roc:[0.9741092810713065, 0.9678593272171254]\n",
      "valid_roc:[0.9131155303030303, 0.9207453973955995]\n",
      "train_roc_mean:0.9709843041442159\n",
      "valid_roc_mean:0.9169304638493149\n",
      "\n",
      "EPOCH:\t78\n",
      "train_roc:[0.967816822247202, 0.9580632008154943]\n",
      "valid_roc:[0.9192708333333333, 0.9185002245172879]\n",
      "train_roc_mean:0.9629400115313482\n",
      "valid_roc_mean:0.9188855289253106\n",
      "\n",
      "EPOCH:\t79\n",
      "train_roc:[0.9728019158398905, 0.9671763506625892]\n",
      "valid_roc:[0.9154829545454546, 0.9238886394252357]\n",
      "train_roc_mean:0.9699891332512398\n",
      "valid_roc_mean:0.9196857969853451\n",
      "\n",
      "EPOCH:\t80\n",
      "train_roc:[0.9737182933385464, 0.9696126401630989]\n",
      "valid_roc:[0.9171401515151515, 0.928378985181859]\n",
      "train_roc_mean:0.9716654667508227\n",
      "valid_roc_mean:0.9227595683485053\n",
      "\n",
      "EPOCH:\t81\n",
      "train_roc:[0.9729851913396217, 0.9688277268093782]\n",
      "valid_roc:[0.9135890151515151, 0.9211944319712618]\n",
      "train_roc_mean:0.9709064590745\n",
      "valid_roc_mean:0.9173917235613884\n",
      "\n",
      "EPOCH:\t82\n",
      "train_roc:[0.9752455891696399, 0.9699490316004078]\n",
      "valid_roc:[0.9022253787878788, 0.9090704984283791]\n",
      "train_roc_mean:0.9725973103850238\n",
      "valid_roc_mean:0.9056479386081289\n",
      "\n",
      "EPOCH:\t83\n",
      "train_roc:[0.9770294707003567, 0.9729357798165138]\n",
      "valid_roc:[0.9192708333333334, 0.9265828468792097]\n",
      "train_roc_mean:0.9749826252584353\n",
      "valid_roc_mean:0.9229268401062716\n",
      "\n",
      "EPOCH:\t84\n",
      "train_roc:[0.9768950686672205, 0.9732619775739042]\n",
      "valid_roc:[0.9197443181818182, 0.9328693309384823]\n",
      "train_roc_mean:0.9750785231205623\n",
      "valid_roc_mean:0.9263068245601502\n",
      "\n",
      "EPOCH:\t85\n",
      "train_roc:[0.978031376765554, 0.9735677879714577]\n",
      "valid_roc:[0.9064867424242424, 0.9095195330040413]\n",
      "train_roc_mean:0.9757995823685058\n",
      "valid_roc_mean:0.9080031377141419\n",
      "\n",
      "EPOCH:\t86\n",
      "train_roc:[0.9770294707003567, 0.9721712538226299]\n",
      "valid_roc:[0.9240056818181819, 0.9342164346654691]\n",
      "train_roc_mean:0.9746003622614934\n",
      "valid_roc_mean:0.9291110582418255\n",
      "\n",
      "EPOCH:\t87\n",
      "train_roc:[0.9767362299007868, 0.9720693170234455]\n",
      "valid_roc:[0.9064867424242425, 0.9086214638527167]\n",
      "train_roc_mean:0.9744027734621161\n",
      "valid_roc_mean:0.9075541031384796\n",
      "\n",
      "EPOCH:\t88\n",
      "train_roc:[0.9791799032305363, 0.9747706422018348]\n",
      "valid_roc:[0.904592803030303, 0.9108666367310283]\n",
      "train_roc_mean:0.9769752727161856\n",
      "valid_roc_mean:0.9077297198806656\n",
      "\n"
     ]
    }
   ],
   "source": [
    "best_param ={}\n",
    "best_param[\"roc_epoch\"] = 0\n",
    "best_param[\"loss_epoch\"] = 0\n",
    "best_param[\"valid_roc\"] = 0\n",
    "best_param[\"valid_loss\"] = 9e8\n",
    "\n",
    "for epoch in range(epochs):    \n",
    "    train_roc, train_loss = eval(model, train_df)\n",
    "    valid_roc, valid_loss = eval(model, valid_df)\n",
    "    train_roc_mean = np.array(train_roc).mean()\n",
    "    valid_roc_mean = np.array(valid_roc).mean()\n",
    "    \n",
    "#     tensorboard.add_scalars('ROC',{'train_roc':train_roc_mean,'valid_roc':valid_roc_mean},epoch)\n",
    "#     tensorboard.add_scalars('Losses',{'train_losses':train_loss,'valid_losses':valid_loss},epoch)\n",
    "\n",
    "    if valid_roc_mean > best_param[\"valid_roc\"]:\n",
    "        best_param[\"roc_epoch\"] = epoch\n",
    "        best_param[\"valid_roc\"] = valid_roc_mean\n",
    "        if valid_roc_mean > 0.85:\n",
    "             torch.save(model, 'saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(epoch)+'.pt')             \n",
    "    if valid_loss < best_param[\"valid_loss\"]:\n",
    "        best_param[\"loss_epoch\"] = epoch\n",
    "        best_param[\"valid_loss\"] = valid_loss\n",
    "\n",
    "    print(\"EPOCH:\\t\"+str(epoch)+'\\n'\\\n",
    "        +\"train_roc\"+\":\"+str(train_roc)+'\\n'\\\n",
    "        +\"valid_roc\"+\":\"+str(valid_roc)+'\\n'\\\n",
    "        +\"train_roc_mean\"+\":\"+str(train_roc_mean)+'\\n'\\\n",
    "        +\"valid_roc_mean\"+\":\"+str(valid_roc_mean)+'\\n'\\\n",
    "        )\n",
    "    if (epoch - best_param[\"roc_epoch\"] >10) and (epoch - best_param[\"loss_epoch\"] >20):        \n",
    "        break\n",
    "        \n",
    "    train(model, train_df, optimizer, loss_function)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "best epoch:67\n",
      "test_roc:[0.9586206896551724, 0.9322916666666666]\n",
      "test_roc_mean: 0.9454561781609195\n"
     ]
    }
   ],
   "source": [
    "# evaluate model\n",
    "best_model = torch.load('saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(best_param[\"roc_epoch\"])+'.pt')     \n",
    "\n",
    "best_model_dict = best_model.state_dict()\n",
    "best_model_wts = copy.deepcopy(best_model_dict)\n",
    "\n",
    "model.load_state_dict(best_model_wts)\n",
    "(best_model.align[0].weight == model.align[0].weight).all()\n",
    "test_roc, test_losses = eval(model, test_df)\n",
    "\n",
    "print(\"best epoch:\"+str(best_param[\"roc_epoch\"])\n",
    "      +\"\\n\"+\"test_roc:\"+str(test_roc)\n",
    "      +\"\\n\"+\"test_roc_mean:\",str(np.array(test_roc).mean())\n",
    "     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
